{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "from torch import nn\n",
    "import torch.nn.utils.prune as prune\n",
    "import torch.nn.functional as F\n",
    "import torchvision.models as models\n",
    "\n",
    "from fastai2.data.external import *\n",
    "from fastai2.data.transforms import *\n",
    "from fastai2.data.block import *\n",
    "from fastai2.vision.models.xresnet import *\n",
    "from fastai2.vision.core import *\n",
    "from fastai2.vision.data import *\n",
    "from fastai2.vision.augment import *\n",
    "\n",
    "from torch.nn.utils import prune\n",
    "from torch.utils.tensorboard import SummaryWriter\n",
    "from utils import prune_scheduler\n",
    "\n",
    "# Number of train/eval/prune cycles run by the loop at the bottom of this cell.\n",
    "NUM_EPOCHS = 1\n",
    "\n",
    "from ResNetLearner import ResNetLearner\n",
    "from TransformerLearner import TransformerLearner\n",
    "\n",
    "if __name__ == \"__main__\":\n",
    "    # A disabled ResNetLearner variant of this experiment used the same\n",
    "    # train/eval/prune loop, with epochs counted from 0 and an interactive\n",
    "    # IPython embed() stop after every epoch. It was dead code kept inside a\n",
    "    # string literal and is omitted here.\n",
    "\n",
    "    # Settings handed to TransformerLearner through its model_config argument.\n",
    "    transformer_config = {\n",
    "        # Hard-coded value; a commented-out alternative was len(TEXT.vocab.stoi).\n",
    "        # NOTE(review): the saved repr of learner.model in this notebook shows\n",
    "        # Embedding(28785, 200), so this value may not be the one actually used - confirm.\n",
    "        \"ntoken\": 2000,\n",
    "        \"emsize\": 200,\n",
    "        \"nhid\": 200,\n",
    "        \"nlayers\": 2,\n",
    "        \"nhead\": 8,\n",
    "        \"dropout\": 0.2,\n",
    "    }\n",
    "\n",
    "    learner = TransformerLearner(model_config=transformer_config)\n",
    "\n",
    "    # Epochs are numbered from 1; each one trains, evaluates, then prunes.\n",
    "    for epoch in range(1, NUM_EPOCHS + 1):\n",
    "        learner.train(epoch)\n",
    "        learner.eval(epoch)\n",
    "        learner.prune_weights(epoch)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "ename": "RuntimeError",
     "evalue": "Trying to backward through the graph a second time, but the buffers have already been freed. Specify retain_graph=True when calling backward the first time.",
     "output_type": "error",
     "traceback": [
      "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[0;31mRuntimeError\u001b[0m                              Traceback (most recent call last)",
      "\u001b[0;32m<ipython-input-5-51193dd10d49>\u001b[0m in \u001b[0;36m<module>\u001b[0;34m\u001b[0m\n\u001b[0;32m----> 1\u001b[0;31m \u001b[0mlearner\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtrain\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;36m2\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m",
      "\u001b[0;32m~/Workspace/libxsmm/samples/deeplearning/sparse_dnn/TransformerLearner.py\u001b[0m in \u001b[0;36mtrain\u001b[0;34m(self, epoch)\u001b[0m\n\u001b[1;32m    141\u001b[0m             \u001b[0moutput\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mmodel\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mdata\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    142\u001b[0m             \u001b[0mloss\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mcrit\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0moutput\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mview\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m-\u001b[0m\u001b[0;36m1\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mntokens\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtargets\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 143\u001b[0;31m             \u001b[0mloss\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mbackward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m    144\u001b[0m             \u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mnn\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mutils\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mclip_grad_norm_\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mmodel\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mparameters\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;36m0.5\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    145\u001b[0m             \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0moptim\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mstep\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m~/anaconda3/envs/fastai2/lib/python3.7/site-packages/torch/tensor.py\u001b[0m in \u001b[0;36mbackward\u001b[0;34m(self, gradient, retain_graph, create_graph)\u001b[0m\n\u001b[1;32m    196\u001b[0m                 \u001b[0mproducts\u001b[0m\u001b[0;34m.\u001b[0m \u001b[0mDefaults\u001b[0m \u001b[0mto\u001b[0m\u001b[0;31m \u001b[0m\u001b[0;31m`\u001b[0m\u001b[0;31m`\u001b[0m\u001b[0;32mFalse\u001b[0m\u001b[0;31m`\u001b[0m\u001b[0;31m`\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    197\u001b[0m         \"\"\"\n\u001b[0;32m--> 198\u001b[0;31m         \u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mautograd\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mbackward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mgradient\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mretain_graph\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mcreate_graph\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m    199\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    200\u001b[0m     \u001b[0;32mdef\u001b[0m \u001b[0mregister_hook\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mhook\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m~/anaconda3/envs/fastai2/lib/python3.7/site-packages/torch/autograd/__init__.py\u001b[0m in \u001b[0;36mbackward\u001b[0;34m(tensors, grad_tensors, retain_graph, create_graph, grad_variables)\u001b[0m\n\u001b[1;32m     98\u001b[0m     Variable._execution_engine.run_backward(\n\u001b[1;32m     99\u001b[0m         \u001b[0mtensors\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mgrad_tensors\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mretain_graph\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mcreate_graph\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 100\u001b[0;31m         allow_unreachable=True)  # allow_unreachable flag\n\u001b[0m\u001b[1;32m    101\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    102\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;31mRuntimeError\u001b[0m: Trying to backward through the graph a second time, but the buffers have already been freed. Specify retain_graph=True when calling backward the first time."
     ]
    }
   ],
   "source": [
    "# Call train for epoch 2 on the learner built in the first cell.\n",
    "# NOTE(review): the saved output of this cell is a RuntimeError (backward through\n",
    "# the graph a second time) raised at loss.backward() inside TransformerLearner.train.\n",
    "# The cause is not visible from this notebook - confirm and fix it in\n",
    "# TransformerLearner.py before relying on this cell.\n",
    "learner.train(2)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "Transformer(\n",
       "  (pos_encoder): PositionalEncoding(\n",
       "    (dropout): Dropout(p=0.2, inplace=False)\n",
       "  )\n",
       "  (transformer_encoder): TransformerEncoder(\n",
       "    (layers): ModuleList(\n",
       "      (0): TransformerEncoderLayer(\n",
       "        (self_attn): MultiheadAttention(\n",
       "          (out_proj): Linear(in_features=200, out_features=200, bias=True)\n",
       "        )\n",
       "        (linear1): Linear(in_features=200, out_features=200, bias=True)\n",
       "        (dropout): Dropout(p=0.2, inplace=False)\n",
       "        (linear2): Linear(in_features=200, out_features=200, bias=True)\n",
       "        (norm1): LayerNorm((200,), eps=1e-05, elementwise_affine=True)\n",
       "        (norm2): LayerNorm((200,), eps=1e-05, elementwise_affine=True)\n",
       "        (dropout1): Dropout(p=0.2, inplace=False)\n",
       "        (dropout2): Dropout(p=0.2, inplace=False)\n",
       "      )\n",
       "      (1): TransformerEncoderLayer(\n",
       "        (self_attn): MultiheadAttention(\n",
       "          (out_proj): Linear(in_features=200, out_features=200, bias=True)\n",
       "        )\n",
       "        (linear1): Linear(in_features=200, out_features=200, bias=True)\n",
       "        (dropout): Dropout(p=0.2, inplace=False)\n",
       "        (linear2): Linear(in_features=200, out_features=200, bias=True)\n",
       "        (norm1): LayerNorm((200,), eps=1e-05, elementwise_affine=True)\n",
       "        (norm2): LayerNorm((200,), eps=1e-05, elementwise_affine=True)\n",
       "        (dropout1): Dropout(p=0.2, inplace=False)\n",
       "        (dropout2): Dropout(p=0.2, inplace=False)\n",
       "      )\n",
       "    )\n",
       "  )\n",
       "  (encoder): Embedding(28785, 200)\n",
       "  (decoder): Linear(in_features=200, out_features=28785, bias=True)\n",
       ")"
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Display the module tree of the learner's model.\n",
    "learner.model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Display the learner's model again.\n",
    "# NOTE(review): this repeats the previous cell and has no saved output - consider removing it.\n",
    "learner.model"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.7"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
